{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from simpletransformers.classification import ClassificationModel, ClassificationArgs\n",
    "from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, accuracy_score\n",
    "import torch\n",
    "\n",
    "# Probe for a GPU once; the model-loading cell passes this flag as use_cuda.\n",
    "cuda_available = torch.cuda.is_available()  # True when a usable CUDA device is present\n",
    "cuda_available"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Load the BERT classifier fine-tuned on ONQ1; it is evaluated cross-corpus on ONQ2/ONQ3 below.\n",
    "# Forward slashes keep the relative path portable: Windows accepts them, whereas POSIX systems\n",
    "# would treat backslashes as part of a single file name.\n",
    "model = ClassificationModel(\n",
    "    \"bert\",\n",
    "    \"02_output data/models/onq1m\",\n",
    "    use_cuda=cuda_available,\n",
    ")"
   ],
   "id": "8ecb02f3701a548a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Import bot and Facebook responses to ONQ2 (the \"discrimination\" question).\n",
    "# Forward slashes keep the relative path portable across Windows and POSIX.\n",
    "df_onq2 = pd.read_excel(\"02_output data/onq2_balanced_sample.xlsx\")\n",
    "\n",
    "# Fail fast if the sheet lacks what the prediction and export cells rely on.\n",
    "missing_onq2 = {\"text\", \"ID\", \"label\"} - set(df_onq2.columns)\n",
    "assert not missing_onq2, f\"ONQ2 sample is missing columns: {sorted(missing_onq2)}\"\n",
    "assert df_onq2['text'].notna().all(), \"ONQ2 sample contains empty responses\"\n",
    "\n",
    "print(df_onq2.shape)\n",
    "df_onq2.head()"
   ],
   "id": "43855df8b23b8391",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Cross-corpus evaluation: the ONQ1-trained model predicts labels for ONQ2 responses.\n",
    "onq2_predictions, onq2_raw_outputs = model.predict(df_onq2['text'].tolist())\n",
    "\n",
    "# Threshold metrics on predicted labels (sklearn's default positive class is 1).\n",
    "onq2_precision = precision_score(df_onq2['label'], onq2_predictions)\n",
    "onq2_recall = recall_score(df_onq2['label'], onq2_predictions)\n",
    "onq2_f1 = f1_score(df_onq2['label'], onq2_predictions)\n",
    "onq2_accuracy = accuracy_score(df_onq2['label'], onq2_predictions)\n",
    "\n",
    "# Ranking metric from the raw outputs. Assumes a 2-column logit array (binary model, no\n",
    "# sliding window) -- TODO confirm. The logit difference is monotonic in P(label=1),\n",
    "# so it gives the same ROC AUC as softmax probabilities.\n",
    "onq2_logits = np.asarray(onq2_raw_outputs)\n",
    "onq2_auc = roc_auc_score(df_onq2['label'], onq2_logits[:, 1] - onq2_logits[:, 0])\n",
    "\n",
    "print(f\"Precision: {onq2_precision}\")\n",
    "print(f\"Recall: {onq2_recall}\")\n",
    "print(f\"F1 Score: {onq2_f1}\")\n",
    "print(f\"Accuracy: {onq2_accuracy}\")\n",
    "print(f\"ROC AUC: {onq2_auc}\")"
   ],
   "id": "6c78d90488c535da",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Export ONQ2 responses with true and predicted labels for error analysis.\n",
    "# Forward slashes keep the relative output path portable across Windows and POSIX.\n",
    "results_onq2 = pd.DataFrame({\n",
    "    \"Text\": df_onq2['text'].tolist(),\n",
    "    \"ID\": df_onq2['ID'].tolist(),\n",
    "    \"True Label\": df_onq2['label'].tolist(),\n",
    "    \"Predicted Label\": onq2_predictions,\n",
    "})\n",
    "\n",
    "results_onq2.to_excel(\"02_output data/onq2_onq1m.xlsx\", index=False, engine=\"openpyxl\")"
   ],
   "id": "fb90ad42d9752dda",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Import bot and Facebook responses to ONQ3 (the \"final comment\" question).\n",
    "# Forward slashes keep the relative path portable across Windows and POSIX.\n",
    "df_onq3 = pd.read_excel(\"02_output data/onq3_balanced_sample.xlsx\")\n",
    "\n",
    "# Fail fast if the sheet lacks what the prediction and export cells rely on.\n",
    "missing_onq3 = {\"text\", \"ID\", \"label\"} - set(df_onq3.columns)\n",
    "assert not missing_onq3, f\"ONQ3 sample is missing columns: {sorted(missing_onq3)}\"\n",
    "assert df_onq3['text'].notna().all(), \"ONQ3 sample contains empty responses\"\n",
    "\n",
    "print(df_onq3.shape)\n",
    "df_onq3.head()"
   ],
   "id": "cb36ce3370388498",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Cross-corpus evaluation: the ONQ1-trained model predicts labels for ONQ3 responses.\n",
    "onq3_predictions, onq3_raw_outputs = model.predict(df_onq3['text'].tolist())\n",
    "\n",
    "# Threshold metrics on predicted labels (sklearn's default positive class is 1).\n",
    "onq3_precision = precision_score(df_onq3['label'], onq3_predictions)\n",
    "onq3_recall = recall_score(df_onq3['label'], onq3_predictions)\n",
    "onq3_f1 = f1_score(df_onq3['label'], onq3_predictions)\n",
    "onq3_accuracy = accuracy_score(df_onq3['label'], onq3_predictions)\n",
    "\n",
    "# Ranking metric from the raw outputs. Assumes a 2-column logit array (binary model, no\n",
    "# sliding window) -- TODO confirm. The logit difference is monotonic in P(label=1),\n",
    "# so it gives the same ROC AUC as softmax probabilities.\n",
    "onq3_logits = np.asarray(onq3_raw_outputs)\n",
    "onq3_auc = roc_auc_score(df_onq3['label'], onq3_logits[:, 1] - onq3_logits[:, 0])\n",
    "\n",
    "print(f\"Precision: {onq3_precision}\")\n",
    "print(f\"Recall: {onq3_recall}\")\n",
    "print(f\"F1 Score: {onq3_f1}\")\n",
    "print(f\"Accuracy: {onq3_accuracy}\")\n",
    "print(f\"ROC AUC: {onq3_auc}\")"
   ],
   "id": "49ecadec911f0023",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Export ONQ3 responses with true and predicted labels for error analysis.\n",
    "# Forward slashes keep the relative output path portable across Windows and POSIX.\n",
    "results_onq3 = pd.DataFrame({\n",
    "    \"Text\": df_onq3['text'].tolist(),\n",
    "    \"ID\": df_onq3['ID'].tolist(),\n",
    "    \"True Label\": df_onq3['label'].tolist(),\n",
    "    \"Predicted Label\": onq3_predictions,\n",
    "})\n",
    "\n",
    "results_onq3.to_excel(\"02_output data/onq3_onq1m.xlsx\", index=False, engine=\"openpyxl\")"
   ],
   "id": "f70c110a0d7c28fc",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
